background mask option for pseudo labeler - unsupervised training#135
background mask option for pseudo labeler - unsupervised training#135stmartineau99 wants to merge 14 commits intocomputational-cell-analytics:mainfrom
Conversation
constantinpape
left a comment
There was a problem hiding this comment.
This looks very good now! The only things to address:
- You are not actually using the new class for the mean teacher trainer.
- Some cosmetics.
See details in the comments.
|
I think some changes for these files did not get merged to the main branch:
|
|
I added gpu device arguments to some other files, and my personal training scripts. |
| # Sample from both the supervised and unsupervised loader. | ||
| for xu1, xu2 in self.unsupervised_train_loader: | ||
|
|
||
| # Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input). |
There was a problem hiding this comment.
Can you please add the assert here or throw some other kind of error if we have a wrong number of channels @stmartineau99 ? See previous comment.
|
|
||
| Args: | ||
| data_paths: The filepaths to the hdf5 files containing the training data. | ||
| data_paths: The filepaths to the mrc files containing the training data. |
There was a problem hiding this comment.
| data_paths: The filepaths to the mrc files containing the training data. | |
| data_paths: The filepaths to the mrc or hdf5 files containing the training data. |
| in_channels: int = 1, | ||
| out_channels: int = 2, | ||
| mask_channel: bool = False, | ||
| device: int = 0, |
There was a problem hiding this comment.
I think that it's better to pass a torch.device here that is optional:
device: Optional[torch.device] = NoneThen you don't need any further changes below.
| x = f(x) | ||
| return x | ||
|
|
||
| class ChannelWiseAugmentations: |
There was a problem hiding this comment.
This class is creating an issue with serialization in Case 1 only (see below) because it is called below like: augmentations = (ChannelWiseAugmentations(weak_augmentations()), ChannelWiseAugmentations(weak_augmentations()))
I rewrote the class to avoid having nested function calls inside get_unsupervised_loader() but this same error still happens.
stmartineau99
left a comment
There was a problem hiding this comment.
- mainly changes to semisupervised_training.py - get_unsupervised_loader()
- changes to domain_adaptation.py - responded to your last comment about asserting dims.
NewPseudoLabelerto subtract background from the pseudo labels if a background mask is givenget_unsupervised_loaderto handle 4 cases:get_unsupervised_loaderhandles each case correctly:DropChannel,ChannelWiseRawTransform,ChannelWiseAugmentationsmean_teacher_adaptationto use theNewPseudoLabelerif the background mask is givenNewMeanTeacherTrainer, which drops the background mask from the teacher input after computing pseudo labels. it also drops the background mask from the student input since this is not used. This behavior only occurs when it recieves training data, and not for validation.NewPseudoLabelerso that it behaves correctly when it recieves training and validation data, since in the case of validation there is no background mask channel.